import logging
import os
import random
import sys

import numpy as np
import torch

from tools.logger import EpochLogger


def preprocess(obs, args):
    """Normalise an observation and place it on the configured device.

    The observation is divided by 255.0 when ``args.image`` is truthy or
    ``args.env_type == "ram"``; otherwise it is passed through unchanged.
    Unless ``args.device == "cpu"``, the result is moved with ``.cuda()``.
    """
    needs_scaling = args.image or args.env_type == "ram"
    result = obs / 255.0 if needs_scaling else obs
    if args.device != "cpu":
        result = result.cuda()
    return result


def size_action_space(action_space):
    """Return the number of choices (``.n``) for a ``Discrete`` space, else ``shape[0]``."""
    if action_space.__class__.__name__ == "Discrete":
        return action_space.n
    return action_space.shape[0]

def dim_action_space(action_space):
    """Return 1 for a ``Discrete`` space (a single index), else ``shape[0]``."""
    if action_space.__class__.__name__ != 'Discrete':
        return action_space.shape[0]
    return 1

def dtype_action_space(action_space):
    """Return ``torch.long`` for a ``Discrete`` space, otherwise ``torch.float32``."""
    is_discrete = action_space.__class__.__name__ == 'Discrete'
    if is_discrete:
        return torch.long
    return torch.float32

def prepare_cuda(args):
    """Seed all RNGs, configure print options, and (optionally) initialise CUDA.

    The torch, ``random`` and numpy RNGs are seeded with ``args.seed``, in that order.
    When ``args.device`` is anything other than ``"cpu"``, CUDA device
    ``cuda:0`` is selected and initialised. CUDA RNGs are seeded and cuDNN is
    put in deterministic, non-benchmark mode.
    """
    seed = args.seed
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

    # NOTE(review): linewidth=NaN presumably disables numpy line wrapping
    # (comparisons against NaN are always False); sys.maxsize is the usual idiom.
    np.set_printoptions(linewidth=np.nan, precision=2)
    torch.set_printoptions(precision=3, linewidth=150)

    if args.device == "cpu":
        return

    # NOTE(review): the device index is hard-coded to 0 regardless of args.device.
    torch.cuda.set_device("cuda:0")
    torch.cuda.init()
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def create_directories(save_dir, seed, params):
    """Create the run's output directories and set up its loggers.

    Args:
        save_dir: Root output directory for this run. It must not already
            exist: ``os.makedirs`` is called without ``exist_ok`` and therefore
            raises ``FileExistsError`` if it does.
        seed: Random seed; used only to build the experiment names.
        params: Run configuration. ``env_type`` and ``plan_interval`` are read,
            and the whole object is passed to ``EpochLogger.save_config``.

    Returns:
        Tuple ``(epoch_logger, coord_epoch_logger, dir_models, eval_logger)``.
        ``dir_models`` is ``save_dir + "/models/"``.

    Raises:
        FileExistsError: if ``save_dir`` (or its ``models`` subdirectory)
            already exists.
    """
    dir_models = save_dir + "/models/"
    # Two separate calls rather than one makedirs(dir_models): this way an
    # existing save_dir raises instead of being silently reused
    # (presumably to avoid mixing runs — confirm intent).
    os.makedirs(save_dir)  # directory for log files
    os.makedirs(dir_models)  # directory for saved models

    epoch_logger = EpochLogger(output_dir=save_dir, exp_name="Seed-" + str(seed))
    coord_epoch_logger = EpochLogger(output_dir=save_dir, exp_name="Coord-Seed-" + str(seed),
                                     output_fname="coord_progress.txt")
    eval_logger = EpochLogger(output_dir=save_dir, exp_name="Eval-Seed-" + str(seed),
                              output_fname="eval_progress.txt")

    if params.env_type in ("multiworld", "mujoco", "maze"):
        setup_logger("distances", save_dir, "/distances.log", out=False, formatter="")
    if params.plan_interval != -1 and params.env_type == "maze":
        setup_logger("plans", save_dir, "/plans.log", out=False, formatter="")

    epoch_logger.save_config(params)
    setup_logger("reject", save_dir, "/rejects.log", out=False)
    return epoch_logger, coord_epoch_logger, dir_models, eval_logger



def setup_logger(name, save_dir=None, log_file="/logs.log", level=logging.INFO, out=True, formatter=None):
    """Create (or reset) the named logger with optional stdout and file output.

    Calling this again with the same ``name`` replaces that logger's handlers.
    This makes it possible to set up as many independent loggers as needed.

    Args:
        name: Name passed to ``logging.getLogger``.
        save_dir: Directory for the log file. When ``None``, no file handler is added.
        log_file: Appended verbatim to ``save_dir``, so it should start with a
            path separator (e.g. ``"/logs.log"``).
        level: Level set on the logger.
        out: When true, also log to ``sys.stdout``.
        formatter: ``logging.Formatter`` format string. ``None`` selects
            ``'%(asctime)s -- %(levelname)s -- %(message)s'``.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    # Close handlers left over from a previous call before discarding them;
    # just clearing the list would leak open FileHandler file descriptors.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    if formatter is None:
        log_format = logging.Formatter('%(asctime)s -- %(levelname)s -- %(message)s')
    else:
        log_format = logging.Formatter(formatter)

    if out:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(log_format)
        logger.addHandler(handler)
    if save_dir is not None:
        filehandler = logging.FileHandler(save_dir + log_file)
        filehandler.setFormatter(log_format)
        logger.addHandler(filehandler)

    logger.setLevel(level)
    return logger

def soft_update(target, source, tau):
    """Blend ``source`` parameters into ``target`` in place.

    Each target parameter becomes ``target * (1 - tau) + source * tau``.
    Parameters are paired in ``.parameters()`` order via ``zip``.
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)

def hot_encoding_index(shape, index, device="cpu", dim=2, action_space=None):
    """One-hot encode ``index`` into a float tensor of shape ``(index.shape[0], *shape)``.

    A 1.0 is written at ``index`` along dimension ``dim`` via ``scatter_``.
    If ``action_space`` is given and is not ``Discrete``, ``index`` is
    returned unchanged (continuous actions need no encoding).
    """
    continuous = action_space is not None and action_space.__class__.__name__ != 'Discrete'
    if continuous:
        return index
    batch = index.shape[0]
    one_hot = torch.zeros((batch, *shape), dtype=torch.float, device=device)
    return one_hot.scatter_(dim, index, 1.)
